from tkinter.font import names

import numpy as np
from torch import nn
from torch.optim import Optimizer, SGD, Adam, LBFGS, Adamax, Adadelta, NAdam
import torch

import torch.nn.functional as F
from scipy.optimize import fsolve
from utils import kernel_toeplitz, data_toeplitz, model_require, pooling_toeplitz, power_iteration, cnn_power_iterations, normalization


def relu_prox(prox_a, prox_b, gamma):
    """Elementwise proximal step through a ReLU.

    Returns, for every element, the minimizer over ``u`` of

        (prox_a - relu(u)) ** 2 + gamma * (u - prox_b) ** 2

    for ``gamma > 0``. The candidates are the unconstrained quadratic solution
    ``(prox_a + gamma * prox_b) / (1 + gamma)`` (when ``u >= 0``),
    ``min(prox_b, 0)`` and ``prox_b`` itself (when ``u < 0``). The branch
    conditions below select the cheaper candidate.

    Args:
        prox_a: target for ``relu(u)``.
        prox_b: centre of the quadratic penalty on ``u``; defines dtype/device.
        gamma: penalty weight, a tensor broadcastable to ``prox_b`` or a
            Python number.

    Returns:
        Tensor shaped like ``prox_b`` with the minimizing ``u``.
    """
    if not torch.is_tensor(gamma):
        # torch.sqrt below needs a tensor; accept plain Python scalars too.
        gamma = torch.tensor(gamma, dtype=prox_b.dtype, device=prox_b.device)

    zero = torch.zeros_like(prox_b)
    shifted = prox_a + gamma * prox_b
    x = shifted / (1 + gamma)  # minimizer of the u >= 0 branch
    y = torch.min(prox_b, zero)  # minimizer of the u <= 0 branch
    # Where prox_b < 0, x beats prox_b exactly when gamma * prox_b >= threshold.
    threshold = prox_a * (gamma - torch.sqrt(gamma * (gamma + 1)))

    val = torch.where(shifted < 0, y, zero)
    val = torch.where(
        ((shifted >= 0) & (prox_b >= 0)) | ((threshold <= gamma * prox_b) & (prox_b < 0)),
        x,
        val)
    val = torch.where((-prox_a <= gamma * prox_b) & (gamma * prox_b <= threshold), prox_b, val)
    return val


class DPGradBCD(Optimizer):
    """Layer-wise block-coordinate-descent optimizer with optional Gaussian noise.

    Besides each layer's weight and bias, every layer ``i`` owns two auxiliary
    tensors, ``u`` and ``x`` (after :meth:`aux_double_init`, ``u`` is the
    layer's affine output and ``x = relu(u)``). Each :meth:`step` sweeps the
    layers from last to first. For each layer it updates the auxiliary
    variables, then the layer parameters. Finally it copies the working
    tensors into the live ones.

    Two dicts keyed by name hold the state:

    * ``params_dict``: the live tensors (from ``params``), the live auxiliary
      tensors, and ``"data"`` / ``"label"`` of the current batch.
    * ``params_dict_new``: the working copies that every update writes to.

    Parameter names follow ``conv{i}.weight`` / ``linear{i}.bias``. Auxiliary
    names follow ``conv{i}_x`` / ``linear{i}_u``. Layers ``1..conv`` use the
    conv templates and later layers use the linear ones.
    """

    def __init__(self, params, layers, batch_size, device, conv=2, lr=1, gamma=None, lossf="mse", alpha=None, rho=None,
                 ns=None, ne=None,
                 lips=False, x_lips=False, epochs=0, shift=0, batch_num=0, input_dim=(3, 32, 32), xd_threshold=1.0):
        """Set up optimizer state and allocate the auxiliary tensors.

        Args:
            params: param-group dicts, each with a ``"name"`` key and a
                one-tensor ``"params"`` list. It is iterated twice (here and by
                the base class), so pass a list, not a generator.
            layers: number of layers; also the index of the output layer.
            batch_size: number of rows of every auxiliary tensor.
            device: device for tensors created by the optimizer.
            conv: number of leading convolutional layers.
            lr: step size; also fills ``rho_lr``.
            gamma, rho, alpha: per-layer penalty weights, broadcast to all
                layers; ``None`` means 1.
            lossf: ``"mse"`` or ``"ce"``.
            ns, ne: start/end of the linear noise-std schedule. Noise is
                enabled unless ``ns`` is ``None`` or 0. With ``ne == -1`` no
                schedule is built.
            lips: if non-zero, ``coeff`` passed to the power-iteration weight
                rescaling after each parameter update.
            x_lips, shift: stored as attributes.
            epochs, batch_num: the schedule has ``epochs * batch_num`` entries.
            input_dim: ``(C, H, W)`` of one sample, used to size conv layer 1.
            xd_threshold: ``threshold`` passed to ``normalization``.
        """
        super().__init__(params, defaults={'lr': lr})
        self.params_dict = {}
        self.params_dict_new = {}
        self.layer = layers
        self.last_layer = layers
        self.first_layer = 1
        self.lr = lr
        self.batch_size = batch_size
        self.device = device
        self.check = True
        self.count = 0
        self.lipschitz = lips
        self.batch_num = batch_num
        if self.lipschitz > 0:
            self.w_lip = self.lipschitz

        self.lip_enable_x = x_lips
        self.x_lip = 0.99
        self.gradient = True
        self.optimizer_list = []
        self.loss = 10000
        self.shift = shift
        self.ns = ns
        self.ne = ne
        self.epochs = epochs
        self.conv = conv
        # ns=None used to enable DP and then crash when building the schedule.
        self.dp = ns is not None and ns != 0

        self.lossf = lossf
        self.input_dim = input_dim
        self.x_bound = xd_threshold

        # One noise-std schedule per layer, indexed by the step counter.
        self.noise_var = []
        if self.dp and ne != -1:
            self.noise_var = [torch.linspace(ns, ne, epochs * batch_num) for _ in range(layers)]

        # Per-layer coefficient vectors, indexed by layer number.
        self.rho_lr = torch.ones(layers + 1) * lr
        self.gamma = torch.ones(layers + 1) * (1 if gamma is None else gamma)
        self.rho = torch.ones(layers + 1) * (1 if rho is None else rho)
        # alpha had no None fallback, so its default raised a TypeError.
        self.alpha = torch.ones(layers + 1) * (1 if alpha is None else alpha)

        # Entry i is the name template for layer i.
        self.param_name_format = ["conv{}.{}"] * (self.conv + 1) + ["linear{}.{}"] * (self.last_layer - self.conv + 1)
        self.aux_name_format = ["conv{}_{}"] * (self.conv + 1) + ["linear{}_{}"] * (self.last_layer - self.conv + 1)

        for param in params:
            self.params_dict[param["name"]] = param["params"][0]
            self.params_dict_new[param["name"]] = param["params"][0].clone().detach()

        self.aux_init(self.batch_size)

        self.output_name = self.aux_name_format[self.last_layer].format(self.last_layer, "x")

    def _product(self, weight, x, b, prod_type="linear", layer=None):
        """Apply one layer's affine map to ``x``.

        ``prod_type == "conv"`` applies a stride-1, unpadded ``conv2d``. Every
        layer except layer 1 first applies 2x2 / stride-2 average pooling.
        Otherwise it returns ``x @ weight.T + b``.
        """
        if prod_type != "conv":
            return x @ weight.T + b
        if layer != 1:
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return F.conv2d(x, weight, bias=b, stride=1, padding=0)

    def _layer_input(self, layer):
        """Return the input of ``layer``: the batch data for the first layer, else layer ``layer - 1``'s ``x``."""
        if layer == self.first_layer:
            return self.params_dict["data"]
        return self.params_dict_new[self.aux_name_format[layer - 1].format(layer - 1, "x")]

    def para_reader(self, layer, para, format="tensor"):
        """Return ``(name, working tensor)`` for ``para`` of ``layer``.

        ``para`` is ``"x"``/``"u"`` for auxiliary tensors, otherwise a
        parameter suffix such as ``"weight"`` or ``"bias"``.

        Raises:
            ValueError: for any ``format`` other than ``"tensor"``. Previously
                this silently returned ``None``.
        """
        if format != "tensor":
            raise ValueError("unsupported format: {!r}".format(format))
        templates = self.aux_name_format if para in ("x", "u") else self.param_name_format
        name = templates[layer].format(layer, para)
        return name, self.params_dict_new[name]

    def para_writer(self, layer, para, data):
        """Copy ``data`` in place into the working tensor ``para`` of ``layer``."""
        templates = self.aux_name_format if para in ("x", "u") else self.param_name_format
        self.params_dict_new[templates[layer].format(layer, para)].copy_(data)

    def aux_init(self, batch_size):
        """Allocate the ``x`` / ``u`` auxiliary tensors of every layer.

        Linear layers get ``(batch_size, out_features)`` zeros. Conv layers are
        sized by running the kernel over random input of shape
        ``(batch_size, *input_dim)`` (layer 1) or over the pooled previous
        ``u``. Here ``x`` is that conv output and ``u = relu(x)``.
        Independent copies go into ``params_dict`` and ``params_dict_new``.
        On the first :meth:`step`, :meth:`aux_double_init` rewrites the
        working copies.
        """
        for i in range(1, self.layer + 1):
            linear_weight_name = "linear{}.weight".format(i)
            if linear_weight_name in self.params_dict:
                dim = self.params_dict[linear_weight_name].size()[0]
                # torch.empty(...) * 0 could keep NaNs (NaN * 0 == NaN).
                x = torch.zeros([batch_size, dim], device=self.device)
                template = "linear{}_{}"
            else:
                weight = self.params_dict["conv{}.weight".format(i)]
                if i == 1:
                    data = torch.randn((batch_size, self.input_dim[0], self.input_dim[1], self.input_dim[2]),
                                       device=self.device)
                    x = F.conv2d(data, weight)
                else:
                    data = self.params_dict["conv{}_u".format(i - 1)]
                    pooling = model_require(None, i)
                    x = F.conv2d(pooling(data), weight)
                template = "conv{}_{}"

            u = nn.ReLU()(x)
            for store in (self.params_dict, self.params_dict_new):
                store[template.format(i, "x")] = x.clone().detach()
                store[template.format(i, "u")] = u.clone().detach()

    def _x_update_last(self, w, b, u_d, x_d, gamma, alpha, name):
        """MSE output layer: closed-form ``x`` update.

        Minimizes ``||x - label||^2 + gamma ||x - u_d||^2 + alpha ||x - x_d||^2``.
        """
        self.params_dict_new[name].copy_(
            (self.params_dict["label"] + gamma * u_d + alpha * x_d) / (1 + gamma + alpha))

    def _x_update_last_cross(self, w, b, u_d, x_d, gamma, alpha, name):
        """Cross-entropy output layer: approximate ``x`` update.

        Runs 1000 AdamW steps (lr 1e-4) from ``x = 0`` on
        ``cross_entropy(x, argmax(label)) + gamma/2 * ||x - relu(u_d)||_F``.
        """
        label = self.params_dict["label"]
        _, targets = label.max(dim=1)  # class indices from (one-hot) labels

        # NOTE(review): the proximal norm is not squared here, unlike the other
        # proximal terms in this class — confirm this is intended.
        def objective(x):
            return F.cross_entropy(x, targets) + gamma / 2 * torch.linalg.matrix_norm(x - F.relu(u_d))

        x = torch.zeros_like(label, requires_grad=True)
        optimizer = torch.optim.AdamW([x], lr=0.0001)
        for _ in range(1000):
            optimizer.zero_grad()
            loss = objective(x)
            loss.backward()
            optimizer.step()
        self.params_dict_new[name].copy_(x.detach())

    def _u_update_last(self, w, b, x_d, x_d_1, gamma, rho, name):
        """Output layer: set ``u`` to the minimizer of ``gamma||u - x_d||^2 + rho||u - (x_d_1 w^T + b)||^2``."""
        self.params_dict_new[name].copy_((gamma * x_d + rho * (x_d_1 @ w.T + b)) / (rho + gamma))

    def _x_update(self, w, b, u_d, u_d_a1, gamma, rho, name, x=None, layer=None):
        """Solve for the new ``x`` of ``layer`` and write it to ``params_dict_new[name]``.

        ``w``, ``b`` and ``u_d_a1`` belong to layer ``layer + 1``; ``u_d`` is
        this layer's ``u``. Conv layers use :meth:`_x_cnn_update`. Linear
        layers use the per-row closed-form minimizer of
        ``rho/2 ||u_d_a1 - x w^T - b||^2 + gamma/2 ||x - relu(u_d)||^2``.
        Both paths pass the result through ``normalization``.
        """
        if layer <= self.conv:
            # The conv solution used to be computed and then thrown away.
            self.params_dict_new[name].copy_(
                self._x_cnn_update(w, b, u_d, u_d_a1, gamma, rho, name, x=x, layer=layer))
            return
        eye = torch.eye(w.shape[1], device=self.device)
        relu_u = nn.ReLU()(u_d)
        x_star = (torch.inverse(rho * w.T @ w + gamma * eye) @ (rho * (w.T @ (u_d_a1 - b).T) + gamma * relu_u.T)).T
        x_star = normalization(x_star, device=self.device, threshold=self.x_bound)
        self.params_dict_new[name].copy_(x_star)

    def _x_cnn_update(self, w_a1, b_a1, u_d, u_d_a1, gamma, rho, name, x=None, layer=None):
        """Solve the ``x`` subproblem for a conv layer and return the result.

        Builds a 2x2 / stride-2 pooling matrix for ``x`` and composes it with
        the next layer's weight, expanded with ``kernel_toeplitz`` when that
        weight is a conv kernel. It then solves
        ``(P^T P + I) x = relu(u_d) + P^T (u_d_a1 - b_a1)`` per sample.
        The result is reshaped to ``x.shape`` and normalized. Nothing is
        written here.

        NOTE(review): ``gamma`` and ``rho`` are not used (both implicitly 1),
        unlike the linear path in :meth:`_x_update` — confirm.

        Raises:
            ValueError: if ``x`` is ``None``.
        """
        if x is None:
            raise ValueError("x cannot be None")
        batch, in_channels, x_w, x_h = x.shape

        pooling_stride = 2
        pooling_para = pooling_toeplitz(x, kernel_size=(2, 2), stride=pooling_stride, device=self.device)
        relu_u = nn.ReLU()(u_d.clone().flatten(start_dim=1))
        eye = torch.eye(relu_u.shape[1], device=self.device)

        if w_a1.dim() != 2:
            # Next layer is a conv layer.
            _, out_channels, out_w, out_h = u_d_a1.shape
            theta_matrix = kernel_toeplitz(w_a1, (batch, in_channels, x_w // pooling_stride, x_h // pooling_stride),
                                           stride=1, device=self.device)
            pooling_x_vector = (theta_matrix @ pooling_para).reshape(out_channels * out_w * out_h, -1).to(
                device=self.device)
            # Per-channel bias broadcast over any spatial size (was hard-coded 10x10).
            u_b = (u_d_a1 - b_a1.reshape(-1, 1, 1)).flatten(start_dim=1)
        else:
            # Next layer is linear. TODO (original): correct for pooling.
            pooling_x_vector = (w_a1 @ pooling_para).to(self.device)
            u_b = u_d_a1 - b_a1

        x_star = torch.inverse(pooling_x_vector.T @ pooling_x_vector + eye) @ (relu_u.T + pooling_x_vector.T @ u_b.T)
        return normalization(x_star.T.reshape(batch, in_channels, x_w, x_h), device=self.device,
                             threshold=self.x_bound)

    def _u_update(self, w, b, x_d, x_d_1, u_d, layer, name):
        """Update ``u`` of a hidden layer with :func:`relu_prox`.

        The layer's affine output on ``x_d_1`` is averaged with the current
        ``u_d`` (weights ``rho`` / ``alpha``). The result then goes through
        ``relu_prox(x_d, ., (rho + alpha) / gamma)``.
        """
        rho = self.rho[layer]
        alpha = self.alpha[layer]
        gamma = self.gamma[layer]

        if layer <= self.conv:
            u_1 = self._product(w, x_d_1, b, "conv", layer=layer)
        elif layer == self.conv + 1:
            # First linear layer: pool and flatten the last conv activation.
            pooling = model_require(None, layer)
            u_1 = self._product(w, pooling(x_d_1.clone()).flatten(start_dim=1), b, "linear")
        else:
            u_1 = self._product(w, x_d_1, b, "linear")

        prox_b = (rho * u_1 + alpha * u_d) / (rho + alpha)
        gamma_hat = (rho + alpha) / gamma
        self.params_dict_new[name].copy_(relu_prox(x_d, prox_b, gamma_hat))

    def _auxiliary_update(self, layer):
        """Update ``x`` then ``u`` of ``layer`` in the working tensors.

        Output layer: ``x`` is fitted to the label (MSE or cross-entropy), and
        then ``u`` is set in closed form. Hidden layers: ``x`` is solved
        against layer ``layer + 1``, and then ``u`` via :meth:`_u_update`.
        All writes are in-place ``copy_``, so references read below stay
        current.
        """
        gamma = self.gamma[layer]
        rho = self.rho[layer]
        _, w = self.para_reader(layer, "weight")
        _, b = self.para_reader(layer, "bias")
        x_name, x_d = self.para_reader(layer, "x")
        u_name, u_d = self.para_reader(layer, "u")

        if layer == self.last_layer:
            if self.lossf == "mse":
                self._x_update_last(w, b, u_d, x_d, gamma, self.alpha[layer], x_name)
            if self.lossf == "ce":
                self._x_update_last_cross(w, b, u_d, x_d, gamma, self.alpha[layer], x_name)
            # x_d now holds the freshly updated x (written in place above).
            self._u_update_last(w, b, x_d, self._layer_input(layer), gamma, rho, u_name)
            return

        _, w_a1 = self.para_reader(layer + 1, "weight")
        _, b_a1 = self.para_reader(layer + 1, "bias")
        _, u_d_a1 = self.para_reader(layer + 1, "u")
        self._x_update(w_a1, b_a1, u_d, u_d_a1, gamma, rho, x_name, x=x_d, layer=layer)
        self._u_update(w, b, x_d, self._layer_input(layer), u_d, layer, u_name)

    def _theta_update(self, w, b, x_d_1, u_d, alpha, rho, w_name, b_name, layer):
        """Closed-form weight/bias update for ``layer``.

        Conv layers go to :meth:`_theta_cnn_update`. For linear layers:

        * ``w`` minimizes ``alpha/2 ||w - w_old||^2 + rho/2 ||u_d - x w^T - b_old||^2``.
        * ``b`` minimizes ``alpha/2 ||b - b_old||^2 + rho/2 ||u_d - x w_old^T - b||^2``.

        With ``self.dp``, scheduled Gaussian noise scaled by ``sqrt(2 rho)``
        is added to both right-hand sides.
        """
        if layer <= self.conv:
            self._theta_cnn_update(w, b, x_d_1, u_d, alpha, rho, w_name, b_name, layer)
            return
        if layer == self.conv + 1:
            x_d_1 = F.avg_pool2d(x_d_1, kernel_size=2, stride=2).flatten(start_dim=1)  # todo: require parameters

        n_rows = x_d_1.shape[0]
        eye = torch.eye(w.shape[1], device=self.device)
        w_rhs = alpha * w + rho * ((u_d - b).T @ x_d_1)
        b_rhs = rho * (torch.sum(u_d - (x_d_1 @ w.T), dim=0)) + alpha * b
        if self.dp:
            # Noise drawn for the weight first, then the bias.
            w_rhs = w_rhs + torch.sqrt(2 * rho) * self.noise_generator(layer, w.size())
            b_rhs = b_rhs + torch.sqrt(2 * rho) * self.noise_generator(layer, b.size())
        w_star = w_rhs @ torch.inverse(alpha * eye + rho * x_d_1.T @ x_d_1)
        b_star = b_rhs / (rho * n_rows + alpha)

        self.params_dict_new[w_name].copy_(w_star)
        self.params_dict_new[b_name].copy_(b_star)

    def _theta_cnn_update(self, w, b, x_d_1, u_d, alpha, rho, w_name, b_name, layer):
        """Closed-form kernel/bias update for a conv layer using ``data_toeplitz``.

        Every output position becomes a row, which turns the conv into a
        linear least-squares problem. The problem is solved as in
        :meth:`_theta_update`, divided through by ``alpha``.
        """
        batch_size, out_channel, out_w, out_h = u_d.shape
        out_channel_k, in_channel_k, w_w, w_h = w.shape
        w_vector = w.clone().flatten(start_dim=1)
        eye = torch.eye(w_vector.shape[1], device=self.device)
        # Shape (batch * out_w * out_h, out_channel).
        u_d_bak = u_d.clone().transpose(0, 1).reshape(out_channel, -1).T

        conv_input = x_d_1 if layer == 1 else F.avg_pool2d(x_d_1, 2, stride=2, padding=0)
        x_d_1_bak = data_toeplitz(conv_input, w, stride=1).reshape(-1, in_channel_k * w_w * w_h)

        ratio = rho * (1 / alpha)
        w_rhs = w_vector + ratio * ((u_d_bak - b).T @ x_d_1_bak)
        if self.dp:
            # NOTE(review): the linear path's noise becomes sqrt(2 rho)/alpha after
            # dividing by alpha; here it is sqrt(2 rho). The bias gets no noise —
            # confirm both against the privacy analysis.
            w_rhs = w_rhs + torch.sqrt(2 * rho) * self.noise_generator(layer, w_vector.size())
        w_star = w_rhs @ torch.inverse(eye + ratio * x_d_1_bak.T @ x_d_1_bak)

        # NOTE(review): DP fits the bias against the new kernel, non-DP against the old one.
        w_for_bias = w_star if self.dp else w_vector
        # "+ 1" is the bias prox term (cf. ``rho * n + alpha`` in _theta_update); it was missing.
        b_star = (ratio * (torch.sum(u_d_bak - (x_d_1_bak @ w_for_bias.T), dim=0)) + b) / (
                ratio * batch_size * out_h * out_w + 1)

        self.params_dict_new[w_name].copy_(w_star.reshape((out_channel_k, in_channel_k, w_w, w_h)))
        self.params_dict_new[b_name].copy_(b_star)

    def _theta_manual_gradient_update(self, w, b, x_d_1, u_d, alpha, rho, w_name, b_name, layer=None):
        """One gradient step of size ``self.lr`` on ``1/(2n) ||u_d - x w^T - b||^2`` (n = ``self.batch_size``)."""
        residual = u_d - x_d_1 @ w.T - b
        w_grad = residual.T @ (-x_d_1) / self.batch_size
        b_grad = (-residual.T).sum(dim=1) / self.batch_size
        self.params_dict_new[w_name].copy_(w - self.lr * w_grad)
        self.params_dict_new[b_name].copy_(b - self.lr * b_grad)

    def grad(self, w, x):
        """Return ``d sum(conv2d(x, w)) / d w`` (stride 1, no padding), shaped like ``w``.

        Previously this was computed and discarded (the method returned ``None``).
        """
        conv_out = F.conv2d(x, w)
        grad_output = torch.ones_like(conv_out)
        # Batch dim acts as the contraction channel; the result is (in, out, kh, kw).
        grad_a = F.conv2d(x.transpose(0, 1), grad_output.transpose(0, 1))
        return grad_a.transpose(0, 1)

    def _theta_fast_gradient_update(self, w, b, x_d_1, u_d, alpha, rho, w_name, b_name, layer=None):
        """Linearized weight/bias update for a linear layer, with ``self.lr`` playing the role of ``1/alpha``.

        If ``self.lipschitz`` is non-zero, ``[w | b]`` is first rescaled by the
        ``power_iteration`` factor.
        """
        w_bak = w.clone()
        b_bak = b.clone()
        eye = torch.eye(w.shape[1], device=self.device)
        if self.lipschitz != 0:
            combined = torch.hstack((w_bak, b_bak.unsqueeze(1)))
            factor = power_iteration(weight=combined, n_iter=3, eps=1e-12, coeff=self.w_lip, matrix=False,
                                     name=w_name)
            w_bak = w_bak / factor
            b_bak = b_bak / factor

        w_bak.requires_grad = True
        b_bak.requires_grad = True
        y = rho / 2 * torch.pow(torch.norm((u_d - (x_d_1 @ w_bak.T + b_bak)), p='fro'), 2)
        y.backward()
        w_grad = w_bak.grad
        # Detach so no autograd graph leaks into params_dict_new via copy_.
        w_bak = w_bak.detach()

        w_star = ((w_bak @ (eye + self.lr * x_d_1.T @ x_d_1) - self.lr * w_grad) @
                  torch.inverse(eye + rho * self.lr * x_d_1.T @ x_d_1))
        # "+ 1" is the bias prox term, as in _theta_update.
        b_star = (rho * self.lr * (torch.sum(u_d - (x_d_1 @ w_bak.T), dim=0)) + b) / (
                rho * self.lr * self.batch_size + 1)

        self.params_dict_new[w_name].copy_(w_star)
        self.params_dict_new[b_name].copy_(b_star)

    def _lipschitz_theta(self, w, b, w_name, b_name, layer=0, input_dim=None):
        """Rescale the weight of ``layer`` with a power-iteration estimate.

        Linear layers: the working weight becomes ``w / factor``, where
        ``factor`` comes from ``power_iteration``.
        """
        if layer <= self.conv:
            # NOTE(review): the return value is discarded, so conv weights change only
            # if cnn_power_iterations mutates ``w`` in place — confirm.
            cnn_power_iterations(n_power_iterations=3, weight=w, stride=1, epsilon=1e-12, input_dim=input_dim,
                                 coeff=self.w_lip, device=self.device)
        else:
            factor = power_iteration(weight=w, n_iter=3, eps=1e-12, coeff=self.w_lip, matrix=False,
                                     name=w_name)
            self.params_dict_new[w_name].copy_(w / factor)

    def _theta_gradient_update(self, w, b, x_d_1, u_d, alpha, rho, w_name, b_name, layer=None):
        """Five inner-optimizer steps on ``rho/2 ||u_d - (x_d_1 w^T + b)||_F^2``.

        Requires ``self.optimizer_list[layer - 1]`` to optimize ``w`` and ``b``.
        ``__init__`` leaves that list empty.
        """
        w.requires_grad = True
        b.requires_grad = True
        optimizer = self.optimizer_list[layer - 1]
        # Assigning ``optimizer.lr`` has no effect; torch reads lr from param groups.
        for group in optimizer.param_groups:
            group["lr"] = float(1 / self.alpha[layer - 1])

        for _ in range(5):
            optimizer.zero_grad()
            y = rho / 2 * torch.pow(torch.norm((u_d - (x_d_1 @ w.T + b)), p='fro'), 2)
            y.backward()
            optimizer.step()
        w.requires_grad = False
        b.requires_grad = False
        if self.dp:
            # Was ``w += w - lr * noise``, which doubled the weights.
            w.sub_(self.lr * self.noise_generator(layer, w.size()))

    def _param_update(self, layer):
        """Update the weight/bias of ``layer``; rescale the weight if ``lipschitz`` is set."""
        w_name, w = self.para_reader(layer, "weight")
        b_name, b = self.para_reader(layer, "bias")
        alpha = self.alpha[layer]
        rho_lr = self.rho_lr[layer]
        x_d_1 = self._layer_input(layer)
        _, u_d = self.para_reader(layer, "u")
        try:
            self._theta_update(w, b, x_d_1, u_d, alpha, rho_lr, w_name, b_name, layer)
        except Exception:
            print("Error updating theta in layer{}".format(layer))
            raise
        if self.lipschitz != 0:
            self._lipschitz_theta(self.params_dict_new[w_name], self.params_dict_new[b_name], w_name, b_name,
                                  layer=layer, input_dim=x_d_1.shape)

    def _upload_parameters(self):
        """Copy working tensors into the live ones, then record the loss.

        Entries with mismatched sizes are reported and skipped. The loss is the
        squared L2 distance (``"mse"``) or ``cross_entropy`` (``"ce"``) between
        the output layer's ``x`` and the label. It is printed and stored in
        ``self.loss``.

        Raises:
            ValueError: if ``self.lossf`` is neither ``"mse"`` nor ``"ce"``.
        """
        if self.check:
            for layer_name, data in self.params_dict.items():
                new_data = self.params_dict_new[layer_name]
                if new_data.size() != data.size():
                    print("wrong in {}".format(layer_name))
                else:
                    data.copy_(new_data)

        output = self.params_dict_new[self.output_name].to(self.device)
        if self.lossf == "mse":
            loss = torch.pow(torch.dist(output, self.params_dict["label"], 2), 2)
        elif self.lossf == "ce":
            loss = F.cross_entropy(output, self.params_dict["label"])
        else:
            raise ValueError("unsupported loss function: {!r}".format(self.lossf))
        print("loss: {}".format(loss))
        self.loss = loss

    def _update(self):
        """One sweep: auxiliary then parameter update for each layer, last to first, then upload."""
        for layer in reversed(range(1, self.layer + 1)):
            self._auxiliary_update(layer)
            self._param_update(layer)
        self._upload_parameters()

    def aux_double_init(self):
        """Re-initialize working weights/biases, then the auxiliary tensors from a forward pass.

        Weights get Kaiming-uniform values (fan_in, relu) and biases are scaled
        by 0.001. Then each layer's ``u`` is set to its affine output on the
        previous layer's ``x`` (the batch data for layer 1), and ``x`` to
        ``relu(u)``.
        """
        for i in range(1, self.layer + 1):
            _, weight = self.para_reader(i, "weight", format="tensor")
            # The randn values are overwritten by kaiming_uniform_; kept so seeded runs reproduce.
            w_init = 0.001 * torch.randn(weight.size(), device=self.device)
            nn.init.kaiming_uniform_(w_init, mode='fan_in', nonlinearity='relu')
            self.para_writer(i, "weight", w_init)
            _, b = self.para_reader(i, "bias", format="tensor")
            self.para_writer(i, "bias", 0.001 * b)

        for i in range(1, self.layer + 1):
            _, w = self.para_reader(i, "weight", format="tensor")
            _, b = self.para_reader(i, "bias", format="tensor")
            x_d_1 = self._layer_input(i)

            if i <= self.conv:
                u_d = self._product(w, x_d_1, b, "conv", layer=i)
            elif i == self.conv + 1:
                pooling = model_require(None, i)
                u_d = self._product(w, pooling(x_d_1.clone()).flatten(start_dim=1), b, "linear")
            else:
                u_d = self._product(w, x_d_1, b, "linear")

            self.para_writer(i, "x", nn.ReLU()(u_d))
            self.para_writer(i, "u", u_d)

    def noise_generator(self, layer, size):
        """Draw zero-mean Gaussian noise of ``size`` using this step's scheduled std for ``layer``."""
        std = float(self.noise_var[layer - 1][self.count])
        return torch.normal(0, std, size, device=self.device)

    def step(self, *args, **kwargs) -> None:
        """Run one BCD sweep on the batch produced by the closure.

        The closure, passed positionally or as ``closure=``, must return
        ``(data, label)``. The data go through ``normalization`` first. On the
        first call, :meth:`aux_double_init` re-initializes the working state.

        Raises:
            TypeError: if no closure is given.
        """
        if args:
            closure = args[0]
        elif "closure" in kwargs:
            closure = kwargs["closure"]
        else:
            raise TypeError("step() requires a closure returning (data, label)")
        data, label = closure()
        data = normalization(data, device=self.device, threshold=self.x_bound)
        self.params_dict["data"] = data
        self.params_dict["label"] = label
        self.params_dict_new["data"] = data
        self.params_dict_new["label"] = label
        self.gradient = False
        if self.count == 0:
            self.aux_double_init()

        self._update()
        self.count += 1
